#Script to create a velocity map (moment 1 map) from a known "good" line, in
#this case one of the stacked CH3OH transitions, then use that velocity map to
#shift-and-stack all of the spectra in the cube.
#The main functionality is spectral-cube's stacking function:
#https://github.com/radio-astro-tools/spectral-cube/blob/master/spectral_cube/analysis_utilities.py#L136
#Adapted from a script by Ginsburg et al. (2018)

from astropy.wcs import WCS
from astropy.coordinates import SkyCoord
import numpy as np
import os
import spectral_cube.analysis_utilities
from spectral_cube import SpectralCube
from astropy import units as u
from astropy.io import fits
import pylab as pl
import regions
import reproject
import pickle
from  radio_beam import *
# Physical constants (cgs units) and source-specific values.
pi=np.pi
kb=1.380649e-16      # Boltzmann constant [erg/K] (exact CODATA 2018 value)
hp=6.626e-27         # Planck constant [erg s]
c=2.99792458e10      # speed of light [cm/s]
d=390.0              # NOTE(review): presumably the distance to V883 Ori in pc; not referenced in this script
pc=3.09e18           # parsec [cm]
jy=1.0e-23           # 1 Jansky [erg s^-1 cm^-2 Hz^-1]
B_0=272912.6e6       # NOTE(review): presumably a rotational constant in Hz; not referenced in this script

def jy_to_k_aperture(bmaj=None,bmin=None,freq=None):
    """Return the Rayleigh-Jeans factor converting Jy to K for an elliptical aperture.

    Uses the module-level cgs constants ``pi``, ``c``, ``jy`` and ``kb``.

    Parameters
    ----------
    bmaj, bmin : float
        Aperture axes in degrees (the ``(pi/180)**2`` factor converts deg**2
        to steradians). All three arguments are required despite the
        ``None`` defaults.
    freq : float
        Frequency in Hz (``c`` is in cm/s, so the wavelength is in cm).

    Returns
    -------
    float
        Brightness temperature in K per Jy, ``lambda**2 * 1 Jy / (2 k Omega)``.
    """
    # Solid angle pi*bmaj*bmin of an ellipse, converted from deg**2 to sr.
    # NOTE(review): if bmaj/bmin are Gaussian FWHMs, the beam solid angle is
    # pi*bmaj*bmin/(4 ln 2) instead; confirm which is intended ("aperture").
    omega = (pi * bmaj * bmin) * (pi / 180.0) ** 2
    wave = c / freq
    return wave ** 2 * jy / (2.0 * kb * omega)

def create_circular_mask2(nx, ny, center=None, radius=None):
    """Build a filled-circle mask and its complement on an ``ny`` x ``nx`` pixel grid.

    Parameters
    ----------
    nx, ny : int
        Number of pixels along x (columns) and y (rows).
    center : sequence of two floats, optional
        ``(x, y)`` pixel position of the circle centre. Defaults to the grid
        centre ``((nx - 1) / 2, (ny - 1) / 2)``.
    radius : float, optional
        Circle radius in pixels. Defaults to the distance from ``center`` to
        the nearest grid edge.

    Returns
    -------
    mask, inverse_mask : ndarray of float, shape (ny, nx)
        ``mask`` is 1.0 where the distance to ``center`` is < ``radius`` and
        0.0 elsewhere; ``inverse_mask`` is 1.0 where the distance is
        >= ``radius``. Both are indexed ``[y, x]`` (row, column).
    """
    if center is None:
        center = ((nx - 1) / 2.0, (ny - 1) / 2.0)
    if radius is None:
        radius = min(center[0], center[1], nx - 1 - center[0], ny - 1 - center[1])

    # Pixel coordinate grids with shape (ny, nx): yy is the row, xx the column.
    yy, xx = np.mgrid[0:ny, 0:nx]
    dist = np.sqrt((xx - center[0]) ** 2 + (yy - center[1]) ** 2)

    # Vectorised comparison; works for non-square grids (nx != ny).
    inside = dist < radius
    outside = dist >= radius
    return inside.astype(float), outside.astype(float)

def rms(x, axis=None):
    """Root-mean-square of ``x``, taken over ``axis`` (all elements when None)."""
    mean_square = np.mean(x ** 2, axis=axis)
    root = np.sqrt(mean_square)
    return root

# One entry per line cube to stack, keyed by a short line label; the loop
# below adds the stacked spectrum and noise estimates to each entry.
_cube_files = (
    ('CH3OH_241', 'V883_Ori_SB_CH3OH-241.846GHz-fullcube_robust_2.0.image.fits'),
    ('HDO_241', 'V883_Ori_SB_HDO-241.558GHz_robust_2.0.image.fits'),
    ('HDO_225', 'V883_Ori_SB_HDO-225.535GHz_robust_2.0.image.fits'),
    ('H218O_203', 'V883_Ori_SB_H218O-203.GHz-0.4kms_robust_2.0.image.fits'),
)
spectral_dict = {label: {'filename': fname} for label, fname in _cube_files}



# ICRS position used later to centre the aperture mask for the noise estimate.
radecstring = '05h38m18.100454s -07d02m25.99340s'
coord = SkyCoord(
    [radecstring],
    frame='icrs',
    unit=(u.hourangle, u.deg),
)

# step 1: create a velocity map
#
# The moment-1 map of the stacked CH3OH cube, limited to pixels with moment-0
# above a threshold and inside CH3OH_enclosing_region.reg, is cached in
# disk_velocity_map.fits and re-read on later runs.

vmap_name = 'disk_velocity_map.fits'
if not os.path.exists(vmap_name):
    cube = SpectralCube.read('V883_Ori_SB_CH3OH_stacked_robust_0.5.image.fits')
    cube2 = cube.with_spectral_unit(u.km / u.s, velocity_convention='radio',rest_value=241.806524 * u.GHz)
    m1 = cube2.moment1()
    m0 = cube2.moment0()
    # Keep only pixels with significant integrated emission.
    # NOTE(review): 0.004 is in m0's units (cube brightness unit * km/s); confirm.
    emission_mask = m0.value > 0.004

    vmap = m1
    vmap[~emission_mask] = np.nan

    # Restrict further to the region enclosing the CH3OH emission.
    r = regions.Regions.read('CH3OH_enclosing_region.reg')[0]
    rp = r.to_pixel(vmap.wcs)
    # to_image places the region mask on the full map grid, clipping the
    # region's bounding box at the map edges (None if they do not overlap).
    region_image = rp.to_mask().to_image(vmap.shape)

    # Pixels outside the region must be NaN, not 0 km/s, so stacking skips
    # them instead of shifting them by a bogus velocity (stack_spectra by
    # default only uses positions where the velocity surface is finite).
    vmap_ = np.full(vmap.shape, np.nan)
    if region_image is not None:
        inside = region_image > 0
        vmap_[inside] = vmap.value[inside]
    hdu = vmap.hdu
    hdu.data = vmap_
    hdu.writeto(vmap_name, overwrite=True)
else:
    # Read the cached map without leaving the FITS file handle open.
    cached_data, cached_header = fits.getdata(vmap_name, header=True)
    hdu = fits.PrimaryHDU(data=cached_data, header=cached_header)
vmap = spectral_cube.lower_dimensional_structures.Projection.from_hdu(hdu)


# step 2: stack
#
# For each cube in spectral_dict:
#   1. shift every spectrum by the velocity map and stack;
#   2. write the stacked spectrum and cube;
#   3. estimate a per-channel noise spectrum from pixels outside a circular
#      aperture around the source.

radius = 0.4  # aperture radius in arcsec (flux scaling and noise mask)

for key, entry in spectral_dict.items():
    filename = entry['filename']

    header = fits.getheader(filename)
    # Aperture and beam sizes in pixels. The *3600 converts CDELT2 to arcsec,
    # so CDELT2/BMAJ/BMIN are presumably in degrees, with square pixels.
    radius_px = radius / (header['CDELT2'] * 3600.0)
    npix_beam = pi * header['BMAJ'] * header['BMIN'] / (4.0 * np.log(2.0)) / header['CDELT2'] ** 2
    npix = pi * radius_px ** 2
    nbeam_per_spectrum = npix / npix_beam

    # load the cube
    fullcube = SpectralCube.read(filename, use_dask=False)
    freq = fullcube.spectral_axis.mean()
    print(freq)

    # Convert to velocity using an arbitrary reference point (the mean
    # spectral value). This assumes the cube's spectral axis is frequency or
    # wavelength; skip this step otherwise.
    fullcube = fullcube.with_spectral_unit(u.km / u.s,
                                           velocity_convention='radio',
                                           rest_value=freq)
    fullcube.allow_huge_operations = True

    # Surface brightness -> flux within the circular aperture. The 1.6 is an
    # empirical fudge factor that brings fluxes closer to the unstacked data.
    # Its origin is unclear; line stacking may be biasing the flux upward.
    fullcube = fullcube.to(u.Jy / u.arcsec ** 2) * (radius ** 2 * pi) / 1.6

    # reproject the velocity map into the cube's coordinate system
    vmap_proj, _ = reproject.reproject_interp(vmap.hdu,
                                              fullcube.wcs.celestial,
                                              shape_out=fullcube.shape[1:])
    vmap_proj = u.Quantity(vmap_proj, u.km / u.s)

    # perform the stacking!
    stack = spectral_cube.analysis_utilities.stack_spectra(fullcube, vmap_proj,
                                                           v0=4.25 * u.km / u.s)
    fstack = stack.with_spectral_unit(u.GHz)

    fstack.write('stacked_' + filename, overwrite=True)
    stack.write('stacked_cube_' + filename, overwrite=True)

    pl.clf()
    fstack.quicklook(filename='stacked_' + filename.replace('.fits', '.png'))
    entry['spectrum'] = fstack.value
    entry['spectrum_Jy_beam'] = fstack.value
    entry['freq_axis'] = fstack.spectral_axis.value
    entry['nbeams'] = nbeam_per_spectrum
    entry['npix_per_beam'] = npix_beam
    entry['npix_per_spectra'] = npix

    # --- noise spectrum from off-source pixels ---
    nchans_stacked = fstack.spectral_axis.value.size
    nchans = header['NAXIS3']
    # The stacked spectrum can be longer than the input cube; the input
    # channels are centred within it.
    diffchans = int((nchans_stacked - nchans) / 2.0)

    # Source pixel position. coord is array-valued, so each axis comes back
    # as a length-1 array.
    xpix, ypix = WCS(header).celestial.world_to_pixel(coord)
    _, inverse_mask = create_circular_mask2(header['NAXIS1'], header['NAXIS2'],
                                            center=[xpix[0], ypix[0]],
                                            radius=radius_px)
    # inverse_mask is indexed [y, x], matching the spatial axes of the
    # (chan, y, x) data, so its indices are used without transposing.
    rows, cols = (inverse_mask == 1).nonzero()

    e_spectrum = np.zeros(nchans_stacked)
    with fits.open(filename) as image:
        offsource = image[0].data[:nchans, rows, cols]  # (nchans, n_offsource_pixels)
        e_spectrum[diffchans:diffchans + nchans] = (
            rms(offsource, axis=1) * nbeam_per_spectrum ** 0.5 / npix_beam)

    # Channels not covered by the input cube are still zero (presumably the
    # stack's edge padding). Fill them with the median of the measured
    # channels. The median is taken once, over measured channels only, so
    # the fill value neither drifts during filling nor collapses to 0 when
    # most channels are padding.
    measured = e_spectrum != 0.0
    if measured.any():
        e_spectrum[~measured] = np.median(e_spectrum[measured])
    entry['e_spectrum'] = e_spectrum



# Save the stacked spectra, read them back, and write one plain-text spectrum
# per line: columns are freq_axis*1e3, spectrum, e_spectrum. freq_axis is in
# GHz as stored by the stacking step, so the first column is in MHz.
# Files are closed explicitly so the pickle is fully flushed before it is
# read back.
with open("spectral_dict_stacked.pickle", "wb") as pickle_out:
    pickle.dump(spectral_dict, pickle_out)

# pickle.load is only safe on trusted input; this file is written just above.
with open("spectral_dict_stacked.pickle", "rb") as pickle_in:
    spectral_dict = pickle.load(pickle_in)

vlsr = 0.0  # presumably km/s (3.0e5 below is c in km/s)
for key, entry in spectral_dict.items():
    # Doppler-shift the frequency axis by vlsr (a no-op for vlsr = 0).
    entry['freq_axis'] = entry['freq_axis'] + vlsr / 3.0e5 * np.median(entry['freq_axis'])
    with open(key + '_stacked_text_spectra.txt', 'w') as outfile:
        for freq_val, flux_val, err_val in zip(entry['freq_axis'],
                                               entry['spectrum'],
                                               entry['e_spectrum']):
            outfile.write('{} {} {} \n'.format(freq_val * 1.0e3, flux_val, err_val))


